#include <zMat.hpp>
#include <UnitTest++.h>

using namespace zzz;

// Unit tests for zSparseMatrix: element access, construction from a dense
// zMatrix, and the ATA helper (A^T * A).
SUITE(zSparseMatrixTest)
{
  // Element insertion and lookup on an initially empty 1000x1000 sparse matrix.
  TEST(SetAndGet)
  {
    zSparseMatrix<double> mat(1000,1000);
    CHECK_EQUAL(1000, mat.Size(0));
    CHECK_EQUAL(1000, mat.Size(1));

    // An entry that was never set is reported as absent and reads back as zero.
    CHECK(!mat.CheckExist(10,100));
    CHECK_EQUAL(0, mat(10,100));
    // AddData stores an explicit entry.
    mat.AddData(10,100,1);
    CHECK(mat.CheckExist(10,100));
    CHECK_EQUAL(1, mat(10,100));
    // MustGet yields an assignable element for (100,100), which was not set before.
    mat.MustGet(100,100)=2;
    CHECK_EQUAL(2, mat(100,100));

    // Exactly the two entries written above are stored.
    CHECK_EQUAL(2, mat.DataSize());
  }

  // Converting a dense matrix keeps its dimensions and only its non-zero entries.
  TEST(AssignFromMatrix)
  {
    zMatrix<double> mat(Zerosd(100,100));
    mat(50,50)=10;
    mat(40,60)=20;
    zSparseMatrix<double> smat(mat);

    // Dimensions carry over from the dense source.
    CHECK_EQUAL(100, smat.Size(0));
    CHECK_EQUAL(100, smat.Size(1));

    // Only the two non-zero entries are stored; zero entries are not.
    CHECK_EQUAL(2, smat.DataSize());
    CHECK(smat.CheckExist(50,50));
    CHECK(smat.CheckExist(40,60));
    CHECK(!smat.CheckExist(0,0));
    CHECK_EQUAL(10, smat(50,50));
    CHECK_EQUAL(20, smat(40,60));
  }

  // Sparse A^T*A, computed both as Trans(sa)*sa and via ATA(sa), must agree
  // with the dense product, which is itself pinned to hand-computed values.
  TEST(ATA)
  {
    zMatrix<double> a(Zerosd(4,4));
    // A is 4x4; its last row is all zeros:
    // 1 2 0 4
    // 3 0 3 0
    // 0 0 1 2
    // 0 0 0 0
    a(0,0)=1;
    a(0,1)=2;
    a(0,3)=4;
    a(1,0)=3;
    a(1,2)=3;
    a(2,2)=1;
    a(2,3)=2;
    zMatrix<double> ata(Trans(a)*a);
    // Expected A^T*A (symmetric):
    // 10 2  9  4
    //  2 4  0  8
    //  9 0 10  2
    //  4 8  2 20
    // 16 entries minus the two zeros at (1,2) and (2,1) gives 14 non-zeros.

    zSparseMatrix<double> sa(a);
    zMatrix<double> sata(Trans(sa)*sa);
    zSparseMatrix<double> ssata(ATA(sa));

    CHECK_EQUAL(4, ssata.Size(0));
    CHECK_EQUAL(4, ssata.Size(1));
    CHECK_EQUAL(14, ssata.DataSize());
    CHECK(ata==sata);
    CHECK(sata==ssata);
    // A^T*A is symmetric.
    CHECK(Trans(ssata)==ssata);

    // Pin the dense reference so the comparisons above are not merely relative.
    CHECK_EQUAL(10, ata(0,0));
    CHECK_EQUAL(2, ata(0,1));
    CHECK_EQUAL(9, ata(0,2));
    CHECK_EQUAL(4, ata(0,3));
    CHECK_EQUAL(4, ata(1,1));
    CHECK_EQUAL(0, ata(1,2));
    CHECK_EQUAL(8, ata(1,3));
    CHECK_EQUAL(10, ata(2,2));
    CHECK_EQUAL(2, ata(2,3));
    CHECK_EQUAL(20, ata(3,3));
  }
}